Tree-based estimators

Tree-based estimators (see the sklearn.tree module and forest of trees in the sklearn.ensemble module) can be used to compute feature importances, which in turn can be used to discard irrelevant features (when coupled with the sklearn.feature_selection.SelectFromModel meta-transformer):


In [19]:
%matplotlib inline  
import pandas as pd
import numpy as np
from sklearn.datasets import load_iris

from sklearn.ensemble import ExtraTreesClassifier
from sklearn.datasets import load_iris
from sklearn.feature_selection import SelectFromModel

import matplotlib.pyplot as plt

iris = load_iris()

In [20]:
data = pd.DataFrame(data=np.c_[iris['data']],
                     columns=iris['feature_names'])

target = pd.DataFrame(data=np.c_[iris['target']],columns=['target'])

X, y = data, target

In [21]:
# Build a forest and compute the feature importances
forest = ExtraTreesClassifier(n_estimators=250,
                              random_state=0)

forest.fit(X, y)
importances = forest.feature_importances_

std = np.std([tree.feature_importances_ for tree in forest.estimators_],
             axis=0)
indices = np.argsort(importances)[::-1]

# Print the feature ranking
print("Feature ranking:")

for f in range(X.shape[1]):
    print("%d. feature %d [%s] (%f) " % (f + 1, indices[f], X.columns[f], importances[indices[f]]))


C:\Anaconda3\lib\site-packages\ipykernel\__main__.py:5: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().
Feature ranking:
1. feature 3 [sepal length (cm)] (0.436836) 
2. feature 2 [sepal width (cm)] (0.408526) 
3. feature 0 [petal length (cm)] (0.090051) 
4. feature 1 [petal width (cm)] (0.064587) 

In [22]:
# Plot the feature importances of the forest
plt.figure()
plt.title("Feature importances")
plt.bar(range(X.shape[1]), importances[indices],
       color="r", yerr=std[indices], align="center")
plt.xticks(range(X.shape[1]), X.columns, rotation='vertical')
plt.xlim([-1, X.shape[1]])
plt.show()



In [ ]:


In [ ]: